{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "source": [
    "# MolMIM Property Guided Molecular Optimization Using CMA-ES\n",
    "\n",
    "Here we demonstrate how to load a [MolMIM](https://arxiv.org/abs/2208.09016) checkpoint from the BioNeMo Framework and use it to optimize some molecules of interest with a custom user-defined scoring function. We use [CMA-ES](https://en.wikipedia.org/wiki/CMA-ES) to traverse the latent space of our MolMIM model and select novel, related molecules expected to improve performance as measured by the scoring function. To sample these molecules, we must complete the following steps:\n",
    "\n",
    "1. Load the desired MolMIM checkpoint.\n",
    "2. Encode the starting molecules into MolMIM's latent space.\n",
    "3. Run CMA-ES, which will iteratively perform the following:\n",
    "    1. Decode latent representations into SMILES strings.\n",
    "    2. Apply the user defined scoring function to these SMILES strings to generate SMILES/scores pairings.\n",
    "    3. Ask the CMA-ES algorithm for a new set of latent space representations from which to sample.\n",
    "\n",
    "Note: this notebook is derived from [a previous tutorial made for the BioNeMo Service version of MolMIM](https://github.com/NVIDIA/BioNeMo/blob/main/examples/service/notebooks/cma_custom_oracles.ipynb).\n",
    "\n",
    "### Setup your environment for this test\n",
    "For this tutorial, we assume you are running within the latest BioNeMo Framework Docker container.\n",
    "\n",
    "From within the Docker container, download the example checkpoint, or use your own: \n",
    "\n",
    "```\n",
    "python download_models.py --download_dir models molmim_70m_24_3\n",
    "```\n",
    "\n",
    "### Load your checkpoint into the molmim inference wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "NGC_API_KEY = input(\"Enter your NGC API key: \")\n",
    "os.environ['NGC_API_KEY'] = NGC_API_KEY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "wget https://raw.githubusercontent.com/brevdev/notebooks/main/assets/setup-bionemo.sh -O setup-bionemo\n",
    "chmod +x setup-bionemo.sh\n",
    "./setup-bionemo.sh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [],
   "source": [
    "from bionemo.utils.hydra import load_model_config\n",
    "from bionemo.model.molecule.molmim.infer import MolMIMInference\n",
    "bionemo_home=f\"/workspace/bionemo\"\n",
    "os.environ['BIONEMO_HOME'] = bionemo_home\n",
    "checkpoint_path = f\"{bionemo_home}/models/molecule/molmim/molmim_70m_24_3.nemo\"\n",
    "cfg = load_model_config(config_name=\"molmim_infer.yaml\", config_path=f\"{bionemo_home}/examples/tests/conf/\") # reasonable starting config for molmim inference\n",
    "# This is the field of the config that we need to set to our desired checkpoint path.\n",
    "cfg.model.downstream_task.restore_from_path = checkpoint_path\n",
    "model = MolMIMInference(cfg, interactive=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup user-defined molecule scoring function\n",
    "This is the section where you as a user can pull in your own scoring functions that you want to optimize. For this example, we will be optimizing a combination of Tanimoto similarity to the input molecule and Quantitative Estimate of Druglikeness (QED) following the example from the initial [MolMIM publication](https://arxiv.org/abs/2208.09016):\n",
    "\n",
    "$$\n",
    "  score = min(QED / 0.9, 1) + min(Tanimoto / 0.4, 1)\n",
    "$$\n",
    "\n",
    "In this case, we will allow the model to optimize up to a maximum QED of 0.9 and Tanimoto similarity of 0.4. Once these maxima are achieved, we perform no further optimization. "
   ]
  },
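  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the arithmetic concrete, here is a minimal sketch of the formula above applied to two made-up (QED, Tanimoto) pairs. The helper name `mixed_score` and the input values are illustrative assumptions, not part of the BioNeMo API.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def mixed_score(qed_value, tanimoto_value):\n",
    "    # Each term is capped at 1, so the best possible score is 2.\n",
    "    return np.clip(qed_value / 0.9, 0.0, 1.0) + np.clip(tanimoto_value / 0.4, 0.0, 1.0)\n",
    "\n",
    "\n",
    "# Tanimoto is already past its 0.4 cap, so only QED can still add score.\n",
    "low_qed = mixed_score(0.45, 0.6)  # 0.5 + 1.0\n",
    "# QED is past its 0.9 cap, so only Tanimoto can still add score.\n",
    "low_sim = mixed_score(0.95, 0.2)  # 1.0 + 0.5\n",
    "print(low_qed, low_sim)\n",
    "```\n",
    "\n",
    "Both molecules receive the same total even though they fall short on different properties, because the two capped terms carry equal weight in the score."
   ]
  },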
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Optional\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from guided_molecule_gen.oracles import qed, tanimoto_similarity\n",
    "\n",
    "def score_mixing_function(qeds, similarities):\n",
    "    # We want to maximize QED and tanimoto similarity up to 0.9 and 0.4, respectively.\n",
    "    return np.clip(qeds / 0.9, a_min=0.0, a_max=1.0) + np.clip(similarities / 0.4, a_min=0.0, a_max=1.0)\n",
    "\n",
    "def try_canon(smiles:str) -> Optional[str]:\n",
    "    try:\n",
    "        return Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True)\n",
    "    except:\n",
    "        return None\n",
    "\n",
    "def canonicalize(smiles: List[str]) -> List[str]:\n",
    "    return [try_canon(s) for s in smiles]\n",
    "\n",
    "\n",
    "def scoring_function(smiles: List[str], reference:str, **kwargs) -> np.ndarray:\n",
    "    \"\"\"Takes a list of SMILES strings and returns an array of scores.\n",
    "\n",
    "    Args:\n",
    "        smiles (List[str]): Smiles strings to generate a score for (one each)\n",
    "        reference (str): Reference molecule (SMILES string) is also used for this scoring function.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: Array of scores, one for each input SMILES string.\n",
    "    \"\"\"\n",
    "    #csmiles = canonicalize(smiles)\n",
    "    scores: np.ndarray = score_mixing_function(qed(smiles), tanimoto_similarity(smiles, reference))\n",
    "    return -1 * scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define starting molecules\n",
    "In this section, we will define the starting molecules for the optimization process. As a set of examples, we will use imatinib, erlotinib, and gifitinib. We ensure that the SMILES strings representing these molecules are canonicalized using RDKit. MolMIM was trained on a corpus of RDKit-cononicalized SMILES strings, so any inputs and outputs should be RDKit-canonicalized as well to achieve peak performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdkit import Chem\n",
    "from rdkit.Chem.QED import qed as rdkit_qed\n",
    "starting_smiles = [\n",
    "    \"CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=CC(=N4)C5=CN=CC=C5\", # imatinib\n",
    "    \"COCCOC1=C(C=C2C(=C1)C(=NC=N2)NC3=CC=CC(=C3)C#C)OCCOC\", # erlotinib\n",
    "    \"C1COCCN1CCCOc2c(OC)cc3ncnc(c3c2)Nc4cc(Cl)c(F)cc4\", # gifitinib\n",
    "]\n",
    "\n",
    "# Canonicalize all SMILES strings and print the structure of imatinib\n",
    "molecules = [Chem.MolFromSmiles(s) for s in starting_smiles]\n",
    "starting_qed = [rdkit_qed(m) for m in molecules]\n",
    "canonicalized_smiles = [Chem.MolToSmiles(m, canonical=True) for m in molecules]\n",
    "molecules[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup the optimizer and wrap the inference API for CMA-ES\n",
    "The CMA-ES library expects certain formats for input/output of the inference model to function properly. We provide a wrapper for this and show how to setup optimization below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bionemo.model.core.controlled_generation import ControlledGenerationPerceiverEncoderInferenceWrapper\n",
    "\n",
    "controlled_gen_kwargs = {\n",
    "    \"sampling_method\": \"beam-search\",\n",
    "    \"sampling_kwarg_overrides\": {\"beam_size\": 3, \"keep_only_best_tokens\": True, \"return_scores\": False},\n",
    "}\n",
    "\n",
    "model_wrapped = ControlledGenerationPerceiverEncoderInferenceWrapper(\n",
    "    model, enforce_perceiver=True, hidden_steps=1, **controlled_gen_kwargs\n",
    ")  # just flatten the position for this."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tune CMA-ES\n",
    "Different models will have different optimal settings for CMA-ES. Here, we perform a grid search over possible values of `sigma`, then perform more steps of optimization with the best. We will use the [Optuna library](https://optuna.org/) to perform this optimization over the `sigma` hyperparameter. This process is referred to as hyperparatemer optimization or HPO."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [],
   "source": [
    "!pip install optuna"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [],
   "source": [
    "from guided_molecule_gen.optimizer import MoleculeGenerationOptimizer\n",
    "import optuna\n",
    "\n",
    "def objective(trial, n_steps:int=10):\n",
    "    sigma = trial.suggest_float('sigma', 0, 2)\n",
    "    optimizer = MoleculeGenerationOptimizer(\n",
    "        model_wrapped,\n",
    "        scoring_function,\n",
    "        canonicalized_smiles,\n",
    "        popsize=10,  # larger values will be slower but more thorough\n",
    "        optimizer_args={\"sigma\": sigma},\n",
    "    )\n",
    "    optimizer.optimize(n_steps)\n",
    "    final_smiles = optimizer.generated_smis\n",
    "    final_score = np.mean([np.min(scoring_function(smis_population, reference_smis)) for smis_population,reference_smis in zip(final_smiles, canonicalized_smiles)])\n",
    "    return final_score\n",
    "\n",
    "study = optuna.create_study()\n",
    "study.optimize(objective, n_trials=50)\n",
    "print(study.best_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can examine the best performing value for `sigma` found by Optuna during HPO."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "study.best_params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Though the above value is the optimium returned over our HPO process, we will consider the range of valid values and pick a minimum that is more likely to be robust. Since the HPO process is stochastic, high-performing and low-performing values may be in close proximity. We would like to identify a good range of `sigma` values, over which the optimizer generally performs well."
   ]
  },
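  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to read this strategy: rather than taking the `sigma` with the lowest fitted loss, take the one whose upper confidence bound is lowest. The sketch below illustrates the difference on made-up numbers; the arrays are hypothetical and are not results from this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical smoothed loss and upper confidence bound on a small grid of sigma values.\n",
    "sigma_grid = np.array([0.25, 0.5, 0.75, 1.0, 1.25])\n",
    "smoothed_loss = np.array([-1.2, -1.6, -1.5, -1.4, -1.1])\n",
    "upper_bound = np.array([-1.0, -1.1, -1.4, -1.3, -0.9])\n",
    "\n",
    "# Nominal choice: the lowest fitted loss, even if its uncertainty is large.\n",
    "nominal_sigma = sigma_grid[np.argmin(smoothed_loss)]\n",
    "# Robust choice: the lowest upper bound, i.e. the best pessimistic estimate.\n",
    "robust_sigma = sigma_grid[np.argmin(upper_bound)]\n",
    "print(nominal_sigma, robust_sigma)\n",
    "```\n",
    "\n",
    "In this toy case the nominal pick (0.5) has a wide interval, so the robust pick shifts to 0.75, where the pessimistic estimate is lowest."
   ]
  },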
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install statsmodels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import statsmodels.api as sm\n",
    "import pandas as pd\n",
    "completed_trials = [trial for trial in study.trials if trial.state == optuna.trial.TrialState.COMPLETE]\n",
    "trials_data = [{\"sigma\": trial.params[\"sigma\"], \"loss\": trial.value, \"trial_id\": tid} for tid,trial in enumerate(completed_trials)]\n",
    "data = pd.DataFrame(trials_data)\n",
    "\n",
    "# Now create a bootstrap confidence interval around the a LOWESS fit\n",
    "\n",
    "\n",
    "def lowess_with_confidence_bounds(\n",
    "    x, y, eval_x, N=200, conf_interval=0.95, lowess_kw=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Perform Lowess regression and determine a confidence interval by bootstrap resampling\n",
    "    \"\"\"\n",
    "    # Lowess smoothing\n",
    "    #  x is called (exog), y is called (endog)\n",
    "    smoothed = sm.nonparametric.lowess(exog=x, endog=y, xvals=eval_x, **lowess_kw)\n",
    "\n",
    "    # Perform bootstrap resamplings of the data\n",
    "    # and  evaluate the smoothing at a fixed set of points\n",
    "    smoothed_values = np.empty((N, len(eval_x)))\n",
    "    for i in range(N):\n",
    "        sample = np.random.choice(len(x), len(x), replace=True)\n",
    "        sampled_x = x[sample]\n",
    "        sampled_y = y[sample]\n",
    "\n",
    "        smoothed_values[i] = sm.nonparametric.lowess(\n",
    "            exog=sampled_x, endog=sampled_y, xvals=eval_x, **lowess_kw\n",
    "        )\n",
    "\n",
    "    # Get the confidence interval\n",
    "    sorted_values = np.sort(smoothed_values, axis=0)\n",
    "    bound = int(N * (1 - conf_interval) / 2)\n",
    "    bottom = sorted_values[bound - 1]\n",
    "    top = sorted_values[-bound]\n",
    "\n",
    "    return smoothed, bottom, top\n",
    "\n",
    "\n",
    "# Compute the 95% confidence interval\n",
    "eval_x = np.linspace(0, 2, 200)\n",
    "smoothed, bottom, top = lowess_with_confidence_bounds(\n",
    "    data.sigma, data.loss, eval_x, lowess_kw={\"frac\": 0.33}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.scatter(x=data[\"sigma\"], y=data[\"loss\"], label=\"observed_points\")\n",
    "plt.plot(eval_x, smoothed, c=\"k\", label=\"lowess smoothed\")\n",
    "plt.fill_between(eval_x, bottom, top, alpha=0.5, color=\"b\", label=\"lowess 95% CI\")\n",
    "plt.legend()\n",
    "plt.title(\"Loss vs sigma from HPO for MolMIM model\")\n",
    "plt.autoscale(enable=True, axis=\"x\", tight=True);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smoothed_best_sigma = eval_x[np.argmin(top)]  # Use the upper bound of the confidence interval\n",
    "smooth_best = {\"sigma\": smoothed_best_sigma}\n",
    "smooth_best"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now compare the smoothed top choice with the best nominal choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smooth_best, study.best_params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run a larger CMA-ES optimization with discovered parameters\n",
    "Given the value of `sigma` we found to work well in our HPO above, we will increase the population size and steps and do a final larger optimizaiton run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [],
   "source": [
    "from tqdm import trange\n",
    "optimizer = MoleculeGenerationOptimizer(\n",
    "        model_wrapped,\n",
    "        scoring_function,\n",
    "        canonicalized_smiles,\n",
    "        popsize=50,  # larger values will be slower but more thorough\n",
    "        optimizer_args=smooth_best,  # Vals from HPO\n",
    "    )\n",
    "# Starting state for idx 0\n",
    "qed_scores = [qed(canonicalized_smiles)]\n",
    "tanimoto_scores = [[tanimoto_similarity([canonicalized_smiles[idx]], canonicalized_smiles[idx])[0] for idx in range(len(canonicalized_smiles))]]\n",
    "best_molecules = [canonicalized_smiles]\n",
    "fraction_bad_samples = [[0]*len(canonicalized_smiles)]\n",
    "for i in trange(30):\n",
    "    optimizer.step()\n",
    "    final_smiles = optimizer.generated_smis\n",
    "    # Population of molecules is returned, but we only want the best one.\n",
    "    _qed_scores = []\n",
    "    _tanimoto_scores = []\n",
    "    _best_molecules = []\n",
    "    _fraction_bad = []\n",
    "    for smis_population,reference_smis in zip(final_smiles, canonicalized_smiles):\n",
    "        idx = np.argmin(scoring_function(smis_population, reference_smis))\n",
    "        _fraction_bad.append(np.mean(qed(smis_population) == 0))\n",
    "        _best_molecules.append(smis_population[idx])\n",
    "        _qed_scores.append(qed([smis_population[idx]])[0])\n",
    "        _tanimoto_scores.append(tanimoto_similarity([smis_population[idx]], reference_smis)[0])\n",
    "    qed_scores.append(_qed_scores)\n",
    "    tanimoto_scores.append(_tanimoto_scores)\n",
    "    best_molecules.append(_best_molecules)\n",
    "    fraction_bad_samples.append(_fraction_bad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explore results\n",
    "Below, we create a plot disaplaying how the components of our target (QED and Tanimoto similarity) changed over each iteration. By our target definition, any value above 0.4 for Tanimoto similarity would be optimal so we expect noise around that value. Similarly, for QED any value above 0.9 would be optimal so we expect noise around that value if any molecule surpasses that threshold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "for i, molecule in enumerate([\"imatinib\", \"erlotinib\", \"gifitinib\"]):\n",
    "    line, = plt.plot(np.arange(len(qed_scores)), [q[i] for q in qed_scores], label=f\"{molecule} QED\")\n",
    "    color = line.get_color()\n",
    "    plt.plot(np.arange(len(tanimoto_scores)), [t[i] for t in tanimoto_scores], label=f\"{molecule} Tanimoto\", linestyle=\"--\", color=color)\n",
    "plt.axhline(y=0.9, color='r', linestyle='-', label=\"QED target\")\n",
    "plt.axhline(y=0.4, color='r', linestyle='--', label=\"Tanimoto target\")\n",
    "plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))\n",
    "plt.xlabel(\"Iteration\")\n",
    "plt.ylabel(\"QED or Tanimoto similarity\")\n",
    "plt.title(\"Targets over time for MolMIM model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How well did our optimization perform?\n",
    "To examine the performance of out optimization, we can quantify the number of invalid samples that were generated. An \"invalid\" SMILES is defined as a SMILES string that does not represent a chemically-valid underlying molecule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(fraction_bad_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can finally quantify the improvement in QED over the baseline value and the fraction of our optimized molecules that maintained the desired Tanimoto similarity threshold above 0.4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qed_improvements = []\n",
    "tanimoto_above_04 = []\n",
    "for i in range(len(starting_qed)):\n",
    "    tanimoto_above_04.append(tanimoto_scores[-1][i] >= 0.4)\n",
    "    qed_improvements.append(qed_scores[-1][i] - starting_qed[i])\n",
    "{\"mean_qed_improvement\": np.mean(qed_improvements), \"tanimoto_above_04\": np.mean(tanimoto_above_04)}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
